#!/usr/bin/env python3
"""
Bayesian Optimization for DAMCTS Hyperparameter Tuning

This script uses Bayesian optimization to find the best hyperparameters
for your DAMCTS algorithm across different environments.
"""

import faulthandler;

faulthandler.enable()

import math
import copy
import gym
import random
import numpy as np
import statistics
import time
import pickle
from pathlib import Path

# Bayesian optimization
try:
    from skopt import gp_minimize
    from skopt.space import Real, Integer, Categorical
    from skopt.utils import use_named_args
    from skopt.acquisition import gaussian_ei
    import matplotlib.pyplot as plt
    from skopt.plots import plot_convergence, plot_objective
except ImportError:
    print("ERROR: Please install scikit-optimize:")
    print("pip install scikit-optimize matplotlib")
    exit(1)

# Import your DAMCTS implementation components
import improved_walker2d
import improved_ant
import improved_humanoid
from SnapshotENV import SnapshotEnv


# ========================================
# DAMCTS Implementation (modified for tuning)
# ========================================

class TunableDAMCTSNode:
    """DAMCTS tree node whose search behaviour is controlled by ``hyperparams``.

    A non-root node is created by stepping ``env`` from its parent's snapshot
    with ``action``; the shaped reward of that transition is stored in
    ``immediate_reward``. Visit statistics (count, sum of returns and sum of
    returns raised to ``hyperparams['power']``) are accumulated by
    :meth:`backpropagate` and drive the power-mean UCB rule in
    :meth:`ucb_score`.

    Keys read from ``hyperparams``: ``reward_offset``, ``reward_scaling``,
    ``power``, ``epsilon_1``, ``beta``, ``l_holder``, ``bonus_c``,
    ``parent_exp`` and ``child_exp``.
    """

    def __init__(self, parent, action, eps_net_func, env, env_name, action_dim, lo, hi, hyperparams):
        """Create a node; when ``parent`` is given, expand it by stepping ``env``.

        Args:
            parent: Parent node, or ``None`` for a root (``env`` is then unused).
            action: Action taken from ``parent`` (stored as float32), or ``None``.
            eps_net_func: Callable mapping an epsilon to a list of candidate actions.
            env: Environment exposing ``get_result(snapshot, action)`` returning
                ``(snapshot, obs, reward, done, info)``.
            env_name: Environment identifier (stored only).
            action_dim: Dimensionality of the action space.
            lo, hi: Scalar action bounds used for random rollouts.
            hyperparams: Hyperparameter dict (passed by reference to every child).
        """
        self.parent = parent
        self.action = None if action is None else np.asarray(action, dtype=np.float32)
        self.children = {}
        self.visit_count = 0
        self.value_sum = 0.0
        self.value_sum_power = 0.0

        self.env_name = env_name
        self._action_dim = int(action_dim)
        self.lo, self.hi = float(lo), float(hi)
        self.hyperparams = hyperparams

        # Environment interaction
        if parent is None:
            self.snapshot = None
            self.obs = None
            self.immediate_reward = 0.0
            self.is_done = False
        else:
            snapshot, obs, reward, done, info = env.get_result(parent.snapshot, self.action)
            self.snapshot = snapshot
            self.obs = obs
            # Shaped reward is clipped to >= 0.01 so returns stay positive,
            # which keeps the fractional power mean well defined.
            r_offset = hyperparams['reward_offset']
            self.immediate_reward = max(0.01, (reward + r_offset) * hyperparams['reward_scaling'])
            self.is_done = bool(done)

        self.eps_net_func = eps_net_func

    @staticmethod
    def _act_key(a, ndigits=6):
        """Hashable dictionary key for an action (float32, rounded to ``ndigits``)."""
        return tuple(np.round(a.astype(np.float32), decimals=ndigits))

    def is_root(self):
        """True when the node has no parent."""
        return self.parent is None

    def get_action_dim(self):
        """Dimensionality of the action space."""
        return self._action_dim

    def get_mean_value(self):
        """Arithmetic mean of backed-up returns (0.0 if never visited)."""
        return self.value_sum / self.visit_count if self.visit_count > 0 else 0.0

    def get_power_mean_value(self):
        """Power mean ``(mean(return ** p)) ** (1/p)`` of backed-up returns (0.0 if unvisited)."""
        if self.visit_count == 0:
            return 0.0
        power = self.hyperparams['power']
        return (self.value_sum_power / self.visit_count) ** (1.0 / power)

    def epsilon_level(self):
        """Return ``(k, eps_k, size)`` for the node's current discretisation level.

        ``eps_k = epsilon_1 * 2 ** (-(k - 1) / (d + 2 * beta))`` shrinks with k.
        ``size`` is the epsilon-net size implied by ``eps_k``. It is capped at
        ``25 ** d`` for d <= 4 and at 2000 otherwise. The first level whose
        ``size ** 2`` covers the visit count is returned.

        Once ``size`` has reached its cap, further refinement cannot enlarge it.
        That level is then returned instead of looping forever, which previously
        happened whenever ``visit_count > cap ** 2`` (e.g. more than 625 visits
        for d = 1).
        """
        n = max(self.visit_count, 1)
        d = self.get_action_dim()
        epsilon_1 = self.hyperparams['epsilon_1']
        beta = self.hyperparams['beta']
        max_size = 25 ** d if d <= 4 else 2000
        k = 1

        while True:
            eps_k = epsilon_1 * (2.0 ** (-(k - 1) / (d + 2.0 * beta)))
            if d <= 4:
                per_dim = max(2, int(np.ceil((1.0 / max(eps_k, 1e-6)) ** (1.0 / d))))
                size = min(25, per_dim) ** d
            else:
                size = int(np.clip((1.0 / max(eps_k, 1e-6)) ** 2, 100, 2000))
            if n <= size * size or size >= max_size:
                return k, eps_k, size
            k += 1

    def ucb_score(self, child_node, parent_visits):
        """Power-mean UCB score of ``child_node`` (``inf`` if it is unvisited).

        score = power_mean(child)
                + l_holder * eps_k ** beta                        (discretisation bias)
                + bonus_c * parent_visits ** parent_exp
                          / child_visits ** child_exp             (exploration bonus)
        """
        if child_node is None or child_node.visit_count == 0:
            return float("inf")

        pm = child_node.get_power_mean_value()

        # Tunable discretization bias
        _, eps_k, _ = self.epsilon_level()
        eps_bonus = self.hyperparams['l_holder'] * (eps_k ** self.hyperparams['beta'])

        # Tunable polynomial exploration bonus
        if parent_visits > 0 and child_node.visit_count > 0:
            bonus_c = self.hyperparams['bonus_c']
            parent_exp = self.hyperparams['parent_exp']
            child_exp = self.hyperparams['child_exp']
            cb = bonus_c * (parent_visits ** parent_exp) / (child_node.visit_count ** child_exp)
        else:
            cb = float("inf")

        return pm + eps_bonus + cb

    def get_child(self, action):
        """Child reached by ``action`` (matched after key rounding), or None."""
        key = self._act_key(action)
        return self.children.get(key, None)

    def tree_policy(self, env):
        """Descend from this node and return the node to evaluate next.

        At each node the epsilon-net for its current level is built. The first
        net action without a child is expanded and the new child is returned.
        If every action already has a child, the highest-UCB child is followed.
        Descent stops at terminal nodes, or when a node has neither untried
        actions nor children.

        NOTE(review): if ``eps_net_func`` returns fresh random samples on every
        call (as ``build_epsilon_net`` does for action_dim > 4), a sample almost
        never matches an existing key. This method then expands at the first
        node it visits on every call, so the tree stays one level deep. Confirm
        this is intended.
        """
        node = self
        while True:
            if node.is_done:
                return node

            # Build current epsilon net
            _, eps_k, _ = node.epsilon_level()
            actions = node.eps_net_func(eps_k)

            # Expand the first discretised action that has not been tried yet
            for a in actions:
                key = node._act_key(a)
                if key not in node.children:
                    child = TunableDAMCTSNode(
                        parent=node,
                        action=a,
                        eps_net_func=node.eps_net_func,
                        env=env,
                        env_name=node.env_name,
                        action_dim=node._action_dim,
                        lo=node.lo,
                        hi=node.hi,
                        hyperparams=node.hyperparams
                    )
                    node.children[key] = child
                    return child

            # All net actions exist: follow the best child by UCB
            best_child = None
            best_score = -float("inf")
            for child in node.children.values():
                score = node.ucb_score(child, node.visit_count)
                if score > best_score:
                    best_score = score
                    best_child = child

            if best_child is None:
                return node
            node = best_child

    def rollout(self, env, max_depth):
        """Discounted (gamma=0.99) shaped return of a uniform-random rollout.

        The rollout starts from this node's snapshot and stops after
        ``max_depth`` steps or when the environment reports ``done``.
        It returns 0.0 for terminal nodes.
        """
        if self.is_done:
            return 0.0

        env.load_snapshot(self.snapshot)

        total = 0.0
        disc = 1.0
        r_offset = self.hyperparams['reward_offset']
        reward_scaling = self.hyperparams['reward_scaling']
        a_dim = self.get_action_dim()

        for _ in range(max_depth):
            a = np.random.uniform(self.lo, self.hi, size=(a_dim,)).astype(np.float32)
            obs, r, done, _ = env.step(a)
            shaped = max(0.01, (r + r_offset) * reward_scaling)
            total += shaped * disc
            disc *= 0.99
            if done:
                break
        return total

    def backpropagate(self, rollout_return):
        """Record one return sample at this node and at every ancestor.

        This node is credited with ``immediate_reward + rollout_return``. Each
        parent receives its own ``immediate_reward`` plus 0.99 times the return
        credited to its child. Every shaped reward along the path is therefore
        counted at every ancestor. The previous code forwarded only the
        discounted rollout tail, which dropped the intermediate rewards.
        """
        node = self
        tail = rollout_return
        while node is not None:
            total_return = node.immediate_reward + tail
            node.visit_count += 1
            node.value_sum += total_return
            node.value_sum_power += total_return ** node.hyperparams['power']
            tail = 0.99 * total_return
            node = node.parent


class TunableDAMCTSRoot(TunableDAMCTSNode):
    """Root node of a DAMCTS tree, anchored at an existing simulator snapshot.

    Unlike ordinary nodes, the root is not produced by stepping an environment.
    The caller supplies the snapshot and observation directly, and the root
    starts with zero immediate reward in a non-terminal state.
    """

    def __init__(self, snapshot, obs, eps_net_func, env_name, action_dim, lo, hi, hyperparams):
        # With parent=None the base initializer never touches ``env``, so None is safe.
        super().__init__(None, None, eps_net_func, None, env_name, action_dim, lo, hi, hyperparams)
        # Replace the placeholders the base class sets for parentless nodes.
        self.snapshot, self.obs = snapshot, obs
        self.immediate_reward, self.is_done = 0.0, False


# ========================================
# Hyperparameter Search Space
# ========================================

# Define the search space for DAMCTS hyperparameters
search_space = [
    Real(5.0, 100.0, name='bonus_c', prior='log-uniform'),  # Exploration bonus coefficient (sampled on a log scale)
    Real(0.1, 0.5, name='parent_exp'),  # Exponent on the parent visit count in the exploration bonus
    Real(0.3, 0.8, name='child_exp'),  # Exponent on the child visit count in the exploration bonus
    Real(1.0, 4.0, name='power'),  # Power-mean backup exponent
    Real(0.1, 1.0, name='epsilon_1'),  # Coarsest (level-1) discretization scale eps_1
    Real(0.5, 2.0, name='beta'),  # Hölder exponent: controls eps_k decay and the eps_k ** beta bias term
    Real(0.5, 2.0, name='l_holder'),  # Hölder constant multiplying eps_k ** beta in the UCB score
    Real(0.5, 2.0, name='reward_scaling'),  # Multiplier applied to (reward + reward_offset)
    Real(5.0, 50.0, name='reward_offset'),  # Added to raw rewards before scaling and clipping at 0.01
    Integer(50, 200, name='max_rollout_depth'),  # Maximum random-rollout length (steps)
]

# Baseline hyperparameters: used by the quick test and as the comparison point
# in evaluate_best(). Keys match the names in search_space, and every value lies
# inside the corresponding search range.
default_hyperparams = {
    'bonus_c': 30.0,
    'parent_exp': 0.25,
    'child_exp': 0.5,
    'power': 2.0,
    'epsilon_1': 0.5,
    'beta': 1.0,
    'l_holder': 1.0,
    'reward_scaling': 1.0,
    'reward_offset': 20.0,
    'max_rollout_depth': 100
}

# ========================================
# Environment Configuration
# ========================================

# Per-environment settings. "action_dim", "lo" and "hi" size and bound the
# epsilon-nets and random actions. "noise_config" is forwarded as keyword
# arguments to gym.make() (presumably consumed by the Improved* environment
# constructors; confirm against those modules).
ENV_CONFIGS = {
    "ImprovedWalker2d-v0": {
        "action_dim": 6,
        "lo": -1.0,
        "hi": 1.0,
        "noise_config": {
            "action_noise_scale": 0.03,
            "dynamics_noise_scale": 0.02,
            "obs_noise_scale": 0.01,
        }
    },
    "ImprovedAnt-v0": {
        "action_dim": 8,
        "lo": -1.0,
        "hi": 1.0,
        "noise_config": {
            "action_noise_scale": 0.03,
            "dynamics_noise_scale": 0.02,
            "obs_noise_scale": 0.01,
        }
    },
    "ImprovedHumanoid-v0": {
        "action_dim": 17,
        "lo": -1.0,
        "hi": 1.0,
        "noise_config": {
            "action_noise_scale": 0.03,
            "dynamics_noise_scale": 0.02,
            "obs_noise_scale": 0.01,
        }
    }
}


# ========================================
# Evaluation Functions
# ========================================

def build_epsilon_net(env_name, action_dim, epsilon, lo=-1.0, hi=1.0):
    """Return a list of candidate float32 actions covering ``[lo, hi] ** action_dim``.

    For ``action_dim <= 4`` the net is a regular grid with
    ``ceil((1/eps) ** (1/action_dim))`` points per axis, clipped to [2, 25].
    Higher-dimensional spaces get ``(1/eps) ** 2`` uniform random samples instead.
    The count is clipped to [100, 1500] for up to 8 dims and to [100, 2000]
    beyond. ``epsilon`` is floored at 1e-3. ``env_name`` is accepted for
    interface compatibility but unused.
    """
    eps = max(epsilon, 1e-3)

    if action_dim > 4:
        upper = 1500 if action_dim <= 8 else 2000
        count = int(np.clip((1.0 / eps) ** min(action_dim, 2), 100, upper))
        cloud = np.random.uniform(lo, hi, size=(count, action_dim)).astype(np.float32)
        # list() over a 2-D array yields its rows.
        return list(cloud)

    points_per_axis = int(np.clip(int(np.ceil((1.0 / eps) ** (1.0 / action_dim))), 2, 25))
    grid_axis = np.linspace(lo, hi, points_per_axis, dtype=np.float32)
    grids = np.meshgrid(*([grid_axis] * action_dim), indexing="ij")
    flat = np.stack([g.ravel() for g in grids], axis=-1)
    return [row.astype(np.float32, copy=False) for row in flat]


def plan_tunable_damcts(root, env, n_iter, max_depth):
    """Grow the DAMCTS tree under ``root`` for ``n_iter`` iterations.

    Each iteration selects or expands a leaf with ``root.tree_policy``.
    Terminal leaves are evaluated as 0.0. Any other leaf is evaluated with a
    random rollout of at most ``max_depth`` steps, and the value is then
    backpropagated.
    """
    for _ in range(n_iter):
        leaf = root.tree_policy(env)
        value = 0.0 if leaf.is_done else leaf.rollout(env, max_depth)
        leaf.backpropagate(value)


def _new_damcts_root(snapshot, obs, eps_net_func, env_name, action_dim, lo, hi, hyperparams):
    """Build a fresh, empty DAMCTS root anchored at ``snapshot`` / ``obs``."""
    return TunableDAMCTSRoot(
        snapshot=snapshot,
        obs=obs,
        eps_net_func=eps_net_func,
        env_name=env_name,
        action_dim=action_dim,
        lo=lo,
        hi=hi,
        hyperparams=hyperparams
    )


def _reroot_at_child(child, snapshot, obs):
    """Promote ``child``'s subtree to a new root anchored at the real state.

    The new root takes ``child``'s statistics and children. Its snapshot and
    observation come from the real (noisy) environment rather than from the
    planner's hypothetical transition. It is non-terminal, because the caller
    only re-roots after a non-terminal real step. The children are re-parented
    so that later backpropagation reaches the new root instead of climbing the
    discarded old tree, which also lets that old tree be garbage-collected.
    """
    new_root = TunableDAMCTSRoot(
        snapshot=snapshot,
        obs=obs,
        eps_net_func=child.eps_net_func,
        env_name=child.env_name,
        action_dim=child._action_dim,
        lo=child.lo,
        hi=child.hi,
        hyperparams=child.hyperparams
    )
    new_root.children = child.children
    for grandchild in new_root.children.values():
        grandchild.parent = new_root
    new_root.visit_count = child.visit_count
    new_root.value_sum = child.value_sum
    new_root.value_sum_power = child.value_sum_power
    return new_root


def _run_damcts_episode(planning_env, test_env, eps_net_func, env_name, action_dim, lo, hi,
                        hyperparams, num_iterations, max_episode_steps):
    """Play one closed-loop DAMCTS episode on ``test_env``.

    Returns:
        The 0.99-discounted sum of raw (unshaped) rewards collected.
    """
    root_obs = test_env.reset()
    root = _new_damcts_root(test_env.get_snapshot(), root_obs, eps_net_func,
                            env_name, action_dim, lo, hi, hyperparams)
    max_depth = hyperparams['max_rollout_depth']

    total_reward = 0.0
    discount = 1.0

    for _ in range(max_episode_steps):
        plan_tunable_damcts(root, planning_env, num_iterations, max_depth)

        # Act greedily w.r.t. the children's power-mean value; fall back to a
        # uniformly random action if planning produced no children.
        if root.children:
            best_child = max(root.children.values(), key=lambda c: c.get_power_mean_value())
            best_action = best_child.action
        else:
            best_child = None
            best_action = np.random.uniform(lo, hi, size=(action_dim,)).astype(np.float32)

        obs, r, done, _ = test_env.step(best_action)
        total_reward += r * discount
        discount *= 0.99

        if done:
            break

        # Re-anchor planning at the real post-step state.
        snap_now = test_env.get_snapshot()
        planning_env.load_snapshot(snap_now)

        if best_child is None:
            root = _new_damcts_root(snap_now, obs, eps_net_func,
                                    env_name, action_dim, lo, hi, hyperparams)
        else:
            root = _reroot_at_child(best_child, snap_now, obs)

    return total_reward


def _close_quietly(env):
    """Close ``env``, reporting (not raising) any error so cleanup stays best-effort."""
    try:
        env.close()
    except Exception as e:
        print(f"Warning: failed to close environment: {e}")


def evaluate_hyperparams(hyperparams_dict, env_name, num_seeds=5, num_iterations=20, max_episode_steps=100):
    """
    Evaluate DAMCTS with given hyperparameters

    Args:
        hyperparams_dict: Dictionary of hyperparameter values (keys as in
            ``search_space``)
        env_name: Environment name to test on (a key of ``ENV_CONFIGS``)
        num_seeds: Number of random seeds to average over (must be >= 1)
        num_iterations: DAMCTS iterations per planning step
        max_episode_steps: Maximum steps per episode (shorter for faster evaluation)

    Returns:
        Negative mean of the per-seed discounted episode returns (negated
        because the Bayesian optimizer minimizes). A seed that raises an
        exception contributes 0.0.

    Raises:
        ValueError: if ``num_seeds`` < 1.
        KeyError: if ``env_name`` is not in ``ENV_CONFIGS``.
    """
    if num_seeds < 1:
        raise ValueError(f"num_seeds must be >= 1, got {num_seeds}")

    env_config = ENV_CONFIGS[env_name]
    action_dim = env_config["action_dim"]
    lo, hi = env_config["lo"], env_config["hi"]
    noise_config = env_config["noise_config"]

    def eps_net_func(epsilon):
        return build_epsilon_net(env_name, action_dim, epsilon, lo, hi)

    seed_returns = []

    for seed in range(num_seeds):
        random.seed(seed)
        np.random.seed(seed)
        # NOTE(review): the gym environments themselves are not seeded here.
        opened_envs = []
        try:
            planning_env = SnapshotEnv(gym.make(env_name, **noise_config).env)
            opened_envs.append(planning_env)
            test_env = SnapshotEnv(gym.make(env_name, **noise_config).env)
            opened_envs.append(test_env)

            episode_return = _run_damcts_episode(
                planning_env, test_env, eps_net_func, env_name, action_dim, lo, hi,
                hyperparams_dict, num_iterations, max_episode_steps
            )
        except Exception as e:
            print(f"Error in seed {seed}: {e}")
            episode_return = 0.0  # Penalty for failed runs
        finally:
            # Always release simulators, including when the episode failed.
            for env in opened_envs:
                _close_quietly(env)

        seed_returns.append(episode_return)

    # Return negative mean (since we want to maximize performance but optimizer minimizes)
    return -statistics.mean(seed_returns)


# ========================================
# Bayesian Optimization
# ========================================

class DAMCTSBayesianOptimizer:
    """Bayesian optimizer (skopt ``gp_minimize``) for DAMCTS hyperparameters.

    The objective evaluates a candidate hyperparameter set on ``target_env``
    with ``evaluate_hyperparams``. It returns the *negative* mean return,
    because ``gp_minimize`` minimises.
    """

    def __init__(self, target_env="ImprovedWalker2d-v0", n_calls=50, n_initial_points=10):
        """
        Args:
            target_env: Environment name (a key of ``ENV_CONFIGS``).
            n_calls: Total number of objective evaluations.
            n_initial_points: Initial evaluations made before the surrogate
                model guides the search.
        """
        self.target_env = target_env
        self.n_calls = n_calls
        self.n_initial_points = n_initial_points
        self.results = None
        self.best_hyperparams = None

        # Create objective function
        @use_named_args(search_space)
        def objective(**params):
            start_time = time.time()

            # Convert to hyperparams dict
            hyperparams = dict(params)

            # Evaluate performance
            performance = evaluate_hyperparams(
                hyperparams,
                self.target_env,
                num_seeds=3,  # Fewer seeds for faster optimization
                num_iterations=15,  # Fewer iterations for faster evaluation
                max_episode_steps=80  # Shorter episodes
            )

            elapsed = time.time() - start_time

            # Log progress
            print(f"Evaluation completed in {elapsed:.1f}s")
            print(f"Performance: {-performance:.3f}")  # Convert back to positive
            print(f"Hyperparams: {hyperparams}")
            print("-" * 50)

            return performance

        self.objective = objective

    def optimize(self):
        """Run Bayesian optimization and return the best hyperparameter dict found."""

        print(f"Starting Bayesian Optimization for {self.target_env}")
        print(f"Search space: {len(search_space)} dimensions")
        print(f"Budget: {self.n_calls} evaluations")
        print("=" * 60)

        # Run optimization
        self.results = gp_minimize(
            func=self.objective,
            dimensions=search_space,
            n_calls=self.n_calls,
            n_initial_points=self.n_initial_points,
            acq_func='EI',  # Expected Improvement
            random_state=42
        )

        # Extract best hyperparameters (results.x is ordered like search_space)
        best_params = self.results.x
        self.best_hyperparams = {
            name: value for name, value in zip([dim.name for dim in search_space], best_params)
        }

        print("\n" + "=" * 60)
        print("OPTIMIZATION COMPLETE!")
        print("=" * 60)
        print(f"Best performance: {-self.results.fun:.3f}")
        print(f"Found after {len(self.results.func_vals)} evaluations")
        print("\nBest hyperparameters:")
        for name, value in self.best_hyperparams.items():
            default_val = default_hyperparams.get(name, "N/A")
            print(f"  {name}: {value:.4f} (default: {default_val})")

        return self.best_hyperparams

    def evaluate_best(self, num_seeds=20, num_iterations=30):
        """Compare the best and default hyperparameters with a more thorough evaluation.

        Returns:
            ``(best_performance, default_performance)``, or ``None`` if
            :meth:`optimize` has not been run.
        """

        if self.best_hyperparams is None:
            print("No optimization results found. Run optimize() first.")
            return

        print(f"\nThorough evaluation of best hyperparameters...")
        print(f"Environment: {self.target_env}")
        print(f"Seeds: {num_seeds}, Iterations: {num_iterations}")

        # Evaluate best hyperparameters
        best_performance = -evaluate_hyperparams(
            self.best_hyperparams,
            self.target_env,
            num_seeds=num_seeds,
            num_iterations=num_iterations,
            max_episode_steps=150
        )

        # Evaluate default hyperparameters for comparison
        default_performance = -evaluate_hyperparams(
            default_hyperparams,
            self.target_env,
            num_seeds=num_seeds,
            num_iterations=num_iterations,
            max_episode_steps=150
        )

        improvement = best_performance - default_performance
        # The relative change is measured against |default|, so the sign still
        # means "better" when the default score is negative. It is undefined
        # when the default score is exactly 0 (e.g. every default seed failed).
        if default_performance != 0:
            improvement_pct = (improvement / abs(default_performance)) * 100
            pct_text = f"{improvement_pct:.1f}%"
        else:
            pct_text = "n/a"

        print(f"\nThorough Evaluation Results:")
        print(f"Best hyperparams:    {best_performance:.3f}")
        print(f"Default hyperparams: {default_performance:.3f}")
        print(f"Improvement:         {improvement:.3f} ({pct_text})")

        return best_performance, default_performance

    def save_results(self, filename=None):
        """Pickle the optimisation results, best hyperparameters and search space.

        skopt keeps the objective in ``results.specs['args']['func']``. Here
        that objective is a closure created in ``__init__``, which ``pickle``
        cannot serialise. It is therefore removed from a copy of the result
        before dumping, the same approach ``skopt.dump(..., store_objective=False)``
        takes.
        """

        if filename is None:
            filename = f"damcts_optimization_{self.target_env}_{int(time.time())}.pkl"

        results_to_save = self.results
        specs = getattr(results_to_save, "specs", None)
        if isinstance(specs, dict) and "func" in specs.get("args", {}):
            results_to_save = copy.deepcopy(results_to_save)
            del results_to_save.specs["args"]["func"]

        results_data = {
            'target_env': self.target_env,
            'optimization_results': results_to_save,
            'best_hyperparams': self.best_hyperparams,
            'search_space': search_space,
            'default_hyperparams': default_hyperparams
        }

        with open(filename, 'wb') as f:
            pickle.dump(results_data, f)

        print(f"Results saved to {filename}")

    def plot_results(self):
        """Plot the convergence trace and per-iteration performance.

        The figure is saved to a PNG and displayed. Plotting errors are reported
        rather than raised.
        """

        if self.results is None:
            print("No results to plot. Run optimize() first.")
            return

        try:
            # Plot convergence
            plt.figure(figsize=(12, 5))

            plt.subplot(1, 2, 1)
            plot_convergence(self.results)
            plt.title(f'Convergence Plot - {self.target_env}')

            plt.subplot(1, 2, 2)
            # Plot performance over iterations
            performances = [-val for val in self.results.func_vals]  # Convert to positive
            plt.plot(performances, 'b-', linewidth=2)
            plt.axhline(y=max(performances), color='r', linestyle='--', alpha=0.7)
            plt.xlabel('Iteration')
            plt.ylabel('Performance')
            plt.title('Performance over Iterations')
            plt.grid(True, alpha=0.3)

            plt.tight_layout()
            plt.savefig(f'damcts_optimization_results_{self.target_env}.png', dpi=300)
            plt.show()

        except Exception as e:
            print(f"Plotting failed: {e}")


# ========================================
# Multi-Environment Optimization
# ========================================

def optimize_multiple_environments(environments, n_calls_per_env=30):
    """Run an independent Bayesian optimisation for each environment in turn.

    For every environment name this creates a DAMCTSBayesianOptimizer and runs
    it. It then saves the results to ``damcts_opt_<env_name>.pkl`` and runs a
    quick 5-seed comparison against the defaults. A summary of the best
    hyperparameters per environment is printed at the end.

    Returns:
        dict mapping env name -> {'best_hyperparams': dict, 'optimizer': optimizer}
    """
    separator = "=" * 80
    initial_points = max(5, n_calls_per_env // 5)
    summary = {}

    for env_name in environments:
        print(f"\n{separator}")
        print(f"OPTIMIZING FOR {env_name}")
        print(separator)

        optimizer = DAMCTSBayesianOptimizer(
            target_env=env_name,
            n_calls=n_calls_per_env,
            n_initial_points=initial_points,
        )
        summary[env_name] = {
            'best_hyperparams': optimizer.optimize(),
            'optimizer': optimizer,
        }

        optimizer.save_results(f"damcts_opt_{env_name}.pkl")
        optimizer.evaluate_best(num_seeds=5, num_iterations=20)

    print(f"\n{separator}")
    print("MULTI-ENVIRONMENT OPTIMIZATION SUMMARY")
    print(separator)

    for env_name, entry in summary.items():
        print(f"\n{env_name}:")
        for param_name, value in entry['best_hyperparams'].items():
            fallback = default_hyperparams.get(param_name, "N/A")
            print(f"  {param_name}: {value:.4f} (default: {fallback})")

    return summary


# ========================================
# Main Execution
# ========================================

if __name__ == "__main__":
    # Menu choices 1-3 share one workflow and differ only in the target environment.
    single_env_targets = {
        "1": "ImprovedWalker2d-v0",
        "2": "ImprovedHumanoid-v0",
        "3": "ImprovedAnt-v0",
    }

    for menu_line in (
        "DAMCTS Bayesian Hyperparameter Optimization",
        "=" * 60,
        "Choose an option:",
        "1. Optimize for single environment (Walker2d)",
        "2. Optimize for single environment (Humanoid)",
        "3. Optimize for single environment (Ant)",
        "4. Optimize for multiple environments",
        "5. Quick test with default hyperparameters",
    ):
        print(menu_line)

    choice = input("Enter choice (1-5): ").strip()

    if choice in single_env_targets:
        # Single environment: optimize, evaluate thoroughly, save and plot
        optimizer = DAMCTSBayesianOptimizer(single_env_targets[choice], n_calls=40)
        best_params = optimizer.optimize()
        optimizer.evaluate_best()
        optimizer.save_results()
        optimizer.plot_results()
    elif choice == "4":
        # Multiple environments
        environments = ["ImprovedWalker2d-v0", "ImprovedAnt-v0", "ImprovedHumanoid-v0"]
        results = optimize_multiple_environments(environments, n_calls_per_env=25)
    elif choice == "5":
        # Quick test
        print("Testing default hyperparameters...")
        performance = -evaluate_hyperparams(
            default_hyperparams, "ImprovedWalker2d-v0",
            num_seeds=3, num_iterations=15, max_episode_steps=50,
        )
        print(f"Default performance on Walker2d: {performance:.3f}")
    else:
        print("Invalid choice. Running quick test...")
        performance = -evaluate_hyperparams(
            default_hyperparams, "ImprovedWalker2d-v0",
            num_seeds=3, num_iterations=10,
        )
        print(f"Default performance: {performance:.3f}")

# Usage instructions:
"""
Installation:
    pip install scikit-optimize matplotlib

Quick start:
    python damcts_bayesian_optimization.py

Expected runtime (rough estimates; actual time depends on hardware and simulator speed):
    - Single environment (40 evaluations): ~2-4 hours
    - Multiple environments: ~6-12 hours
    - Quick test: ~2-5 minutes

The optimizer will:
1. Search over 10 key DAMCTS hyperparameters
2. Use Gaussian Process regression for efficient search
3. Save results and generate plots
4. Provide thorough evaluation of best parameters found
"""